"""Utility functions for testing the wrappers."""

from __future__ import annotations

import gymnasium as gym
from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
from tests.testing_env import GenericTestEnv


# Shared constants for the wrapper tests. None of them is referenced in the visible
# part of this module, so the notes below are assumptions to confirm against the
# test files that import them.
SEED = 42  # presumably the seed passed to env.reset / space seeding -- confirm
ENV_ID = "CartPole-v1"  # presumably the environment id handed to the env factory -- confirm
DISCRETE_ACTION = 0  # presumably a fixed action for discrete-action envs -- confirm
NUM_ENVS = 3  # presumably the number of sub-environments in vector tests -- confirm
NUM_STEPS = 20  # presumably the number of steps taken per rollout -- confirm


def record_obs_reset(self: gym.Env, seed=None, options: dict | None = None):
    """Records and uses an observation passed through options.

    Args:
        self: The environment the function is attached to (not used).
        seed: Accepted to match the reset signature; ignored.
        options: Mapping that must contain an ``"obs"`` entry. The default is ``None``
            only to mirror the reset signature: calling without options raises
            ``TypeError`` when the entry is looked up.

    Returns:
        A tuple ``(obs, info)`` where ``obs`` is ``options["obs"]`` and ``info`` is
        ``{"obs": obs}``.
    """
    obs = options["obs"]
    return obs, {"obs": obs}


def record_random_obs_reset(self: gym.Env, seed=None, options=None):
    """Samples an observation from the environment's observation space and records it.

    ``seed`` and ``options`` are accepted but ignored.

    Returns:
        A tuple ``(obs, info)`` where ``obs`` is one draw from
        ``self.observation_space.sample()`` and ``info`` is ``{"obs": obs}``.
    """
    sampled = self.observation_space.sample()
    info = {"obs": sampled}
    return sampled, info


def record_action_step(self: gym.Env, action):
    """Echoes the action received by the environment back through the info dict.

    Returns:
        The 5-tuple ``(0, 0, False, False, {"action": action})``: a constant
        observation and reward of 0, both end-of-episode flags False, and the action
        stored under the ``"action"`` info key.
    """
    obs, reward = 0, 0
    terminated = truncated = False
    info = {"action": action}
    return obs, reward, terminated, truncated, info


def record_random_obs_step(self: gym.Env, action):
    """Samples an observation from the observation space and records it in the info.

    The ``action`` argument is ignored.

    Returns:
        The 5-tuple ``(obs, 0, False, False, {"obs": obs})`` where ``obs`` is one draw
        from ``self.observation_space.sample()``.
    """
    sampled = self.observation_space.sample()
    info = {"obs": sampled}
    return sampled, 0, False, False, info


def record_action_as_obs_step(self: gym.Env, action):
    """Returns the action unchanged as the observation and records it in the info.

    Returns:
        The 5-tuple ``(action, 0, False, False, {"obs": action})``.
    """
    info = {"obs": action}
    reward = 0
    return action, reward, False, False, info


def record_action_as_record_step(self: gym.Env, action):
    """Returns the action unchanged as the reward and records it in the info.

    Despite "record" in the function name, the action is placed in the reward slot
    and stored under the ``"reward"`` info key.

    Returns:
        The 5-tuple ``(0, action, False, False, {"reward": action})``.
    """
    info = {"reward": action}
    obs = 0
    return obs, action, False, False, info


def check_obs(
    env: gym.Env,
    wrapped_env: gym.Wrapper,
    transformed_obs,
    original_obs,
    strict: bool = True,
):
    """Asserts that each observation belongs to the observation space of its own environment.

    The transformed observation must be contained in the wrapped environment's
    observation space and the original observation in the base environment's.
    With ``strict`` enabled, each observation must additionally NOT be contained
    in the other environment's observation space.

    Args:
        env: The base environment
        wrapped_env: The wrapped environment
        transformed_obs: The observation produced by the wrapped environment
        original_obs: The observation produced by the base environment
        strict: Whether to also assert that neither observation is contained in the
            other environment's observation space.

    Raises:
        AssertionError: If any containment check fails; the message holds the
            offending observation followed by the space it was checked against.
    """
    wrapped_space = wrapped_env.observation_space
    assert transformed_obs in wrapped_space, f"{transformed_obs}, {wrapped_space}"

    base_space = env.observation_space
    assert original_obs in base_space, f"{original_obs}, {base_space}"

    if not strict:
        return

    # Strict mode: the two spaces must tell the observations apart.
    assert transformed_obs not in base_space, f"{transformed_obs}, {base_space}"
    assert original_obs not in wrapped_space, f"{original_obs}, {wrapped_space}"


# One GenericTestEnv per entry of TESTING_SPACES, built at import time with that space
# passed as `observation_space`. The ids name is bound to TESTING_SPACES_IDS itself
# (the same object, not a copy), so ids and envs line up by index.
TESTING_OBS_ENVS = [GenericTestEnv(observation_space=space) for space in TESTING_SPACES]
TESTING_OBS_ENVS_IDS = TESTING_SPACES_IDS

# Same construction, but each testing space is passed as `action_space` instead.
TESTING_ACTION_ENVS = [GenericTestEnv(action_space=space) for space in TESTING_SPACES]
TESTING_ACTION_ENVS_IDS = TESTING_SPACES_IDS


def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool:
    """Reports whether any wrapper layer around the environment is of the given type.

    Walks inward through successive ``.env`` attributes for as long as the current
    layer is a ``gym.Wrapper``; the first non-wrapper object ends the search and is
    itself never compared against ``wrapper_type``.

    Args:
        wrapped_env: The (possibly wrapped) environment to inspect.
        wrapper_type: The wrapper class to look for.

    Returns:
        True if a wrapper layer is an instance of ``wrapper_type``, otherwise False.
    """
    layer = wrapped_env
    while True:
        if not isinstance(layer, gym.Wrapper):
            # Reached the innermost, unwrapped environment without a match.
            return False
        if isinstance(layer, wrapper_type):
            return True
        layer = layer.env
